package nl.cypherpunk.statefuzzer;

import de.learnlib.algorithms.lstar.mealy.ExtensibleLStarMealyBuilder;
import de.learnlib.api.SUL;
import de.learnlib.api.algorithm.LearningAlgorithm;
import de.learnlib.api.logging.LearnLogger;
import de.learnlib.api.oracle.EquivalenceOracle;
import de.learnlib.api.query.DefaultQuery;
import de.learnlib.filter.cache.mealy.MealyCacheOracle;
import de.learnlib.filter.cache.mealy.MealyCaches;
import de.learnlib.filter.statistic.Counter;
import de.learnlib.filter.statistic.oracle.MealyCounterOracle;
import de.learnlib.oracle.membership.SULOracle;
import de.learnlib.util.statistics.SimpleProfiler;
import net.automatalib.automata.transducers.MealyMachine;
import net.automatalib.serialization.dot.GraphDOT;
import net.automatalib.words.Alphabet;
import net.automatalib.words.Word;
import nl.cypherpunk.statefuzzer.openvpn.VPNConfig;
import nl.cypherpunk.statefuzzer.openvpn.VPNSUL;
import nl.cypherpunk.statefuzzer.rtsp.RTSPConfig;
import nl.cypherpunk.statefuzzer.rtsp.RTSPSUL;
import nl.cypherpunk.statefuzzer.smtp.SMTPConfig;
import nl.cypherpunk.statefuzzer.smtp.SMTPSUL;
import nl.cypherpunk.statefuzzer.tls.TLSConfig;
import nl.cypherpunk.statefuzzer.tls.TLSSUL;

import java.io.File;
import java.io.IOException;
import java.io.PrintStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

/**
 * Learns a Mealy-machine model of a protocol implementation (the SUL) with L*
 * and a modified W-method equivalence oracle.
 *
 * <p>The constructor selects and creates the SUL from {@code config.type};
 * {@link #learn()} wires the oracles, runs the learning loop and writes every
 * hypothesis as a dot file into {@code config.output_dir}.
 */
public class ProLearner {
    LearningConfig config;
    Alphabet<String> alphabet;
    boolean combine_query = false;
    SUL<String, String> sul;
    SULOracle<String, String> memOracle;
    LogOracle.MealyLogOracle<String, String> logMemOracle;
    MealyCounterOracle<String, String> statsMemOracle;
    MealyCacheOracle<String, String> cachedMemOracle;
    MealyCounterOracle<String, String> statsCachedMemOracle;
    LearningAlgorithm<MealyMachine<?, String, ?, String>, String, Word<String>> learningAlgorithm;

    SULOracle<String, String> eqOracle;
    LogOracle.MealyLogOracle<String, String> logEqOracle;
    MealyCounterOracle<String, String> statsEqOracle;
    MealyCacheOracle<String, String> cachedEqOracle;
    MealyCounterOracle<String, String> statsCachedEqOracle;
    EquivalenceOracle<MealyMachine<?, String, ?, String>, String, Word<String>> equivalenceAlgorithm;

    /**
     * Creates the output directory (if missing) and the SUL matching {@code config.type}.
     *
     * <p>If {@code config.type} matches none of the handled types, {@code sul} and
     * {@code alphabet} are left {@code null}.
     *
     * @param config learning configuration; {@code output_dir} and {@code type} are read here
     * @throws Exception if the output directory cannot be created, a helper process
     *                   cannot be started or awaited, or SUL construction fails
     */
    public ProLearner(LearningConfig config) throws Exception {
        this.config = config;

        // Create output directory if it doesn't exist
        Path path = Paths.get(config.output_dir);
        if (Files.notExists(path)) {
            Files.createDirectories(path);
        }

        // The logger name is derived from Learner, not ProLearner; kept unchanged so
        // the logger this class writes to stays the same.
        LearnLogger log = LearnLogger.getLogger(Learner.class.getSimpleName());

        // Check the type of learning we want to do and create corresponding configuration and SUL
        if (config.type == LearningConfig.TYPE_TLS) {
            log.logPhase( "Using TLS SUL");

            // Create the TLS SUL
            sul = new TLSSUL(new TLSConfig(config));
            alphabet = ((TLSSUL) sul).getAlphabet();
        }
        else if (config.type == LearningConfig.TYPE_VPN) {
            log.logPhase("Using VPN SUL");

            // Create the VPN SUL
            sul = new VPNSUL(new VPNConfig(config));
            alphabet = ((VPNSUL) sul).getAlphabet();
        }
        else if (config.type == LearningConfig.TYPE_SMTP) {
            log.logPhase("Using SMTP SUL");

            // Run the two exim clean-up commands before creating the SUL.
            runAndWait("sudo", "killall", "exim");
            // NOTE(review): no shell is involved, so the '*' is handed to rm literally
            // and is not expanded as a glob - confirm whether this clean-up is effective.
            runAndWait("sudo", "rm", "/home/exim*");

            // Create the SMTP SUL
            sul = new SMTPSUL(new SMTPConfig(config));
            alphabet = ((SMTPSUL) sul).getAlphabet();
        }
        else if (config.type == LearningConfig.TYPE_RTSP) {
            log.logPhase("Using RTSP SUL");

            // Create the RTSP SUL
            sul = new RTSPSUL(new RTSPConfig(config));
            alphabet = ((RTSPSUL) sul).getAlphabet();
        }
    }

    /** Starts {@code command} as an external process and blocks until it exits; the exit code is ignored. */
    private static void runAndWait(String... command) throws IOException, InterruptedException {
        Process process = new ProcessBuilder(command).start();
        process.waitFor();
    }

    /**
     * Runs the learning loop until the equivalence oracle finds no counter-example,
     * or reports the same counter-example twice in a row.
     *
     * <p>Each round's hypothesis is written to {@code hypothesis_<round>.dot}; the
     * model at termination is written to {@code learnedModel.dot}. Profiling and
     * query statistics are printed to standard output at the end.
     *
     * @throws IOException if a model file cannot be written or {@code dot} cannot be started
     * @throws InterruptedException declared by {@link #writeDotModel}
     */
    public void learn() throws IOException, InterruptedException {
        LearnLogger log = LearnLogger.getLogger(Learner.class.getSimpleName());

        // One oracle chain shared by the learner and the equivalence test:
        // log oracle on the SUL -> counter -> tree cache. Assign the field rather
        // than a shadowing local so the instance state reflects what is in use.
        logEqOracle = new LogOracle.MealyLogOracle<String, String>(this.sul, LearnLogger.getLogger("queries"));
        statsEqOracle = new MealyCounterOracle<String, String>(logEqOracle, "MembershipQuery");

        cachedMemOracle = MealyCaches.createTreeCache(this.alphabet, statsEqOracle);
        equivalenceAlgorithm = new ModifiedWMethodEQOracle.MealyModifiedWMethodEQOracle<String, String>(config.max_depth, cachedMemOracle);
        learningAlgorithm
                = new ExtensibleLStarMealyBuilder<String, String>().withAlphabet(alphabet).withOracle(cachedMemOracle).create();

        log.logPhase( "Using learning algorithm " + learningAlgorithm.getClass().getSimpleName());
        log.logPhase( "Using equivalence algorithm " + equivalenceAlgorithm.getClass().getSimpleName());

        log.logPhase( "Starting learning");

        SimpleProfiler.start("Total time");

        boolean learning = true;
        Counter round = new Counter("Rounds", "");

        round.increment();
        System.out.println("Starting round " + round.getCount());
        SimpleProfiler.start("Learning");
        learningAlgorithm.startLearning();
        SimpleProfiler.stop("Learning");

        MealyMachine<?, String, ?, String> hypothesis = learningAlgorithm.getHypothesisModel();
        DefaultQuery<String, Word<String>> previousCounterExample = null;
        while (learning) {
            // Write outputs
            writeDotModel(hypothesis, alphabet, config.output_dir + "/hypothesis_" + round.getCount() + ".dot");

            // Search counter-example
            System.out.println("Searching for counter-example");
            SimpleProfiler.start("Searching for counter-example");
            DefaultQuery<String, Word<String>> counterExample = equivalenceAlgorithm.findCounterExample(hypothesis, alphabet);
            SimpleProfiler.stop("Searching for counter-example");

            if (counterExample == null) {
                // No counter-example found, so done learning
                learning = false;

                // Write outputs
                writeDotModel(hypothesis, alphabet, config.output_dir + "/learnedModel.dot");
            }
            else {
                if (previousCounterExample != null
                        && previousCounterExample.toString().equals(counterExample.toString())) {
                    // Same counter-example as in the previous round: stop after this
                    // iteration. The model is written before the refinement below, so
                    // learnedModel.dot holds the pre-refinement hypothesis.
                    learning = false;

                    // Write outputs
                    writeDotModel(hypothesis, alphabet, config.output_dir + "/learnedModel.dot");
                }
                else {
                    previousCounterExample = counterExample;
                }

                // Counter example found, update hypothesis and continue learning
                System.out.println("Counter-example found: " + counterExample.toString());
                //TODO Add more logging
                round.increment();
                System.out.println("Starting round " + round.getCount());

                SimpleProfiler.start("Learning");
                learningAlgorithm.refineHypothesis(counterExample);
                SimpleProfiler.stop("Learning");

                hypothesis = learningAlgorithm.getHypothesisModel();
            }
        }

        SimpleProfiler.stop("Total time");

        // Output statistics
        System.out.println("-------------------------------------------------------");
        System.out.println( SimpleProfiler.getResults());
        System.out.println(round.getSummary());
        // statsEqOracle is the only counter wired into the chain above;
        // statsCachedMemOracle is never assigned in this class and would be null here.
        System.out.println(statsEqOracle.getStatisticalData().getSummary());
        System.out.println( "States in final hypothesis: " + hypothesis.size());
    }

    /**
     * Writes {@code model} in dot format to {@code filename} and starts
     * {@code dot -Tpdf -O <filename>} without waiting for it to finish.
     *
     * @param model    the Mealy machine to serialize
     * @param alphabet the input alphabet used when writing the model
     * @param filename path of the dot file to create
     * @throws IOException if the file cannot be written or the {@code dot} process cannot be started
     * @throws InterruptedException kept in the signature for existing callers
     */
    public static void writeDotModel(MealyMachine<?, String, ?, String> model, Alphabet<String> alphabet, String filename) throws IOException, InterruptedException {
        // Write output to dot-file; try-with-resources closes the stream even if writing fails
        try (PrintStream psDotFile = new PrintStream(new File(filename))) {
            GraphDOT.write(model, alphabet, psDotFile);
        }

        //TODO Check if dot is available

        // Convert .dot to .pdf. The filename is passed as a single argument so that
        // paths containing whitespace are not split into several arguments.
        new ProcessBuilder("dot", "-Tpdf", "-O", filename).start();
    }

    /**
     * Command-line entry point.
     *
     * @param args {@code args[0]} is passed to the {@code LearningConfig} constructor
     * @throws Exception if configuration loading, SUL creation or learning fails
     */
    public static void main(String[] args) throws Exception {
        if (args.length < 1) {
            System.err.println("Invalid number of parameters");
            System.exit(-1);
        }

        LearningConfig config = new LearningConfig(args[0]);

        // Run this class's learner; previously the sibling Learner was instantiated here.
        ProLearner learner = new ProLearner(config);
        learner.learn();

        System.exit(0);
    }
}
